from DoubleLinkedList import DoubleLinkedList, Node


class LFUNode(Node):
    """ 继承Node节点类，新加属性freq表示频率 """

    def __init__(self, key, value) -> None:
        super(LFUNode, self).__init__(key, value)
        # 频率
        self.freq = 0


class LFUCache():

    def __init__(self, capacity):
        """
        param:  map: 依然保存所有节点的映射，方便查询 
        param:  freq_map: 字典类型。
                    key: 频率， value：频率对应的双链表。
                    当发生淘汰时：
                    1. 选择频率最低的链表
                    2. 选择从链表头部删除（因为尾部添加元素， 头部的元素在时间上，也是最久未被访问的）

        param:  size: 当前节点个数。
                    没有一个链表存储所有节点。 而是把节点都分开，按照频率，存放在每个小的链表里
        param:  capacity: 总容量 
        """
        self.capacity = capacity
        self.map = {}
        self.freq_map = {}
        self.size = 0


    def __update_freq(self, node):
        """ 更新结点频率 """
        freq = node.freq
        # 将原来所以在的频率链表的节点删除
        node = self.freq_map[freq].remove(node)
        if self.freq_map[freq].cur_size == 0:
            del self.freq_map[freq]
        # 更新后插入到新的频率链表
        freq += 1
        node.freq = freq
        if freq not in self.freq_map: 
            self.freq_map[freq] = DoubleLinkedList()
        # 向链表尾部添加
        self.freq_map[freq].append(node)  




    def get(self, key):
        node = self.map.get(key, None)
        if not node: return -1
        # 找到节点，更新节点的频率
        self.__update_freq(node)

        return node.value

    
    def put(self, key, value):
        node = self.map.get(key, None)
        # 缓存命中的时候
        if node: 
            node.value = value
            self.__update_freq(node)
        # 缓存不命中
        else:
            # 容量已满，需要淘汰
            if self.size >= self.capacity:
                min_freq = min(self.freq_map)
                # 从头部去掉
                node = self.freq_map[min_freq].pop()
                del self.map[node.key]
                self.size -= 1
            node = LFUNode(key, value)
            node.freq = 1
            self.map[key] = node
            # 如果不存在该频率（1频率） 的链表，那就创建一个先
            if node.freq not in self.freq_map:
                self.freq_map[node.freq] = DoubleLinkedList()
            # 尾部添加
            node = self.freq_map[node.freq].append(node)  
            self.size += 1
    
    def print(self):
        print('-'*50)
        for k, v in self.freq_map.items():
            print('freq = %d' % k)
            v.print()
        print('$'*50)


# 测试代码
if __name__ == "__main__":
    cache = LFUCache(2)
    cache.put(1, 1)
    cache.print()
    cache.put(2, 2)
    cache.print()
    print(cache.get(1))
    cache.print()
    cache.put(3, 3)
    cache.print()
    print(cache.get(2))
    cache.print()
    print(cache.get(3))
    cache.print()
    cache.put(4, 4)
    cache.print()
    print(cache.get(1))
    cache.print()
    print(cache.get(3))
    cache.print()
    print(cache.get(4))
    cache.print()


    